{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from typing import List, Optional, Tuple, Union\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.checkpoint\n",
    "from torch import nn\n",
    "import os\n",
    "from torch.utils.data import IterableDataset, Dataset\n",
    "import json\n",
    "import numpy as np\n",
    "from transformers import  PreTrainedModel\n",
    "from transformers.modeling_outputs import CausalLMOutputWithPast\n",
    "from transformers import PretrainedConfig\n",
    "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator, DataCollatorForTokenClassification, AutoConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rmsnorm\n",
    "# 计算layernorm时是减去均值除以标准差，然后乘以权重。而rmsnorm没有减去均值，是直接除以均方根，然后乘以权重。\n",
    "class RMSNorm(nn.Module):\n",
    "    def __init__(self, hidden_size, eps=1e-6):\n",
    "        \n",
    "        super().__init__()\n",
    "        self.weight = nn.Parameter(torch.ones(hidden_size))\n",
    "        self.variance_epsilon = eps\n",
    "\n",
    "    def forward(self, hidden_states):\n",
    "        hidden_states = hidden_states.float()\n",
    "        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n",
    "        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n",
    "        return self.weight * hidden_states.float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 旋转操作\n",
    "def rotate_half(x):\n",
    "    x1, x2 = x.chunk(2, dim=-1)\n",
    "    return torch.cat((-x2, x1), dim=-1)\n",
    "\n",
    "# 应用位置编码\n",
    "def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):\n",
    "    \n",
    "    cos = cos.unsqueeze(unsqueeze_dim) # (1, seq_len, 1, dim)\n",
    "    sin = sin.unsqueeze(unsqueeze_dim) # (1, seq_len, 1, dim)\n",
    "   \n",
    "    q_embed = (q*cos) + (rotate_half(q)*sin)  # (batch_size, seq_len, head_num, dim) * (1, seq_len, 1, dim) = (batch_size, seq_len, head_num, dim) 广播\n",
    "    k_embed = (k*cos) + (rotate_half(k)*sin)  # (batch_size, seq_len, head_num, dim) * (1, seq_len, 1, dim) = = (batch_size, seq_len, head_num, dim) 广播\n",
    "    \n",
    "    return q_embed, k_embed\n",
    "\n",
    "class RotaryEmbedding(nn.Module):\n",
    "    def __init__(self, dim, max_seq_len=2048):\n",
    "        super(RotaryEmbedding, self).__init__()\n",
    "        self.dim = dim\n",
    "        self.max_seq_len = max_seq_len\n",
    "        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))  # 形状(dim/2)\n",
    "        t = torch.arange(max_seq_len).float().unsqueeze(1)  # 形状(max_seq_len, 1)\n",
    "        freqs = t @ inv_freq.unsqueeze(0)  #(max_seq_len, 1)*(1, dim/2) = (max_seq_len, dim/2)\n",
    "        freqs = torch.cat((freqs, freqs), dim=-1)  # (max_seq_len, dim)\n",
    "        \n",
    "        self.register_buffer(\"cos_cached\", freqs.cos())\n",
    "        self.register_buffer(\"sin_cached\", freqs.sin())\n",
    "        \n",
    "    def forward(self, q, k):\n",
    "        cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)  # (1, seq_len, dim)\n",
    "        sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)  # (1, seq_len, dim)\n",
    "        return apply_rotate_pos_emb(q, k, cos, sin)\n",
    "    \n",
    "    \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 对于group query，一个q共享多个k，v，需要对k，v进行复制\n",
    "def repeat_kv(hidden_states, n_rep):\n",
    "    \n",
    "    batch, slen, num_key_value_heads, head_dim = hidden_states.shape\n",
    "    if n_rep == 1:\n",
    "        return hidden_states\n",
    "    hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_key_value_heads, n_rep, head_dim)\n",
    "    return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, head_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# 注意力机制\n",
    "class Attention(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        self.dropout = config.dropout\n",
    "        self.hidden_size = config.hidden_size\n",
    "        self.num_heads = config.num_attention_heads\n",
    "        self.head_dim = getattr(config, \"head_dim\", self.hidden_size // self.num_heads)\n",
    "        self.num_key_value_heads = config.num_key_value_heads\n",
    "        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n",
    "        self.k_cache, self.v_cache = None, None\n",
    "        self.is_causal = True\n",
    "        self.flash_attn = self.config.flash_attn\n",
    "\n",
    "        # 初始化q,k,v变换矩阵\n",
    "        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)\n",
    "        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)\n",
    "        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)\n",
    "        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)\n",
    "        self.residual_dropout = nn.Dropout(self.dropout)\n",
    "        self.attention_dropout = nn.Dropout(self.dropout)\n",
    "        self.rotary_emb = RotaryEmbedding(self.head_dim)\n",
    "        \n",
    "    def forward(self, hidden_states, use_kv_cache=False):\n",
    "        b, s = hidden_states.shape[:2]\n",
    "    \n",
    "        if use_kv_cache and self.eval():\n",
    "            if self.k_cache is None or self.k_cache.shape[1] != s-1:\n",
    "                q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)\n",
    "            else:\n",
    "                # 获取最新生成的token\n",
    "                token = hidden_states[:, -1:, :] # 形状(b, 1, dim)\n",
    "                q = torch.cat((torch.zeros_like(hidden_states[:, :-1, :]), self.q_proj(token)), dim=1) \n",
    "                # 新的k,v和之前已经生成的进行拼接\n",
    "                k = torch.cat((self.k_cache, self.k_proj(token)), dim=1)\n",
    "                v = torch.cat((self.v_cache, self.v_proj(token)), dim=1)\n",
    "            # 更新cache\n",
    "            self.k_cache, self.v_cache = k, v\n",
    "            \n",
    "        else:\n",
    "            q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)\n",
    "            \n",
    "        q = q.view(b, s, self.num_heads, self.head_dim)\n",
    "        k = k.view(b, s, self.num_key_value_heads, self.head_dim)\n",
    "        v = v.view(b, s, self.num_key_value_heads, self.head_dim)\n",
    "        \n",
    "        q, k = self.rotary_emb(q, k)\n",
    "        \n",
    "        k = repeat_kv(k, self.num_key_value_groups)\n",
    "        v = repeat_kv(v, self.num_key_value_groups)\n",
    "        \n",
    "        q = q.transpose(1, 2) # b, self.num_heads, s, self.head_dim\n",
    "        k = k.transpose(1, 2) # b, self.num_heads, s, self.head_dim\n",
    "        v = v.transpose(1, 2) # b, self.num_heads, s, self.head_dim\n",
    "        \n",
    "        if self.flash_attn:\n",
    "        \n",
    "            # q*k转置，（b, self.num_heads, s, self.head_dim）* (b, self.num_heads, self.head_dim，s) = （b, self.num_heads, s, s）\n",
    "            # q*k/sqrt(self.head_dim)*v  （b, self.num_heads, s, s）* (b, self.num_heads, s, self.head_dim) = b, self.num_heads, s, self.head_dim\n",
    "            output = F.scaled_dot_product_attention(q, k, v, attn_mask=None, \n",
    "                                                    dropout_p=self.dropout if self.training else 0.0, \n",
    "                                                    is_causal=self.is_causal) \n",
    "        else:\n",
    "            mask = torch.full((1, 1, self.config.max_seq_len, self.config.max_seq_len), float(\"-inf\"))  # 初始化掩码\n",
    "            mask = torch.triu(mask, diagonal=1)  # 生成上三角掩码\n",
    "            scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)  # 计算注意力分数\n",
    "            scores = scores + self.mask[:, :, :s, :s]  # 应用掩码\n",
    "            scores = F.softmax(scores.float(), dim=-1).type_as(q)  # 计算 softmax\n",
    "            scores = self.attention_dropout(scores)  # 应用注意力 dropout\n",
    "            output = torch.matmul(scores, v)  # 计算输出\n",
    "        \n",
    "        output = output.transpose(1, 2).contiguous().view(b, s, -1) # b, s, self.hidden_size\n",
    "        \n",
    "        output = self.o_proj(output)\n",
    "        output = self.residual_dropout(output)\n",
    "        return output\n",
    "                \n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        self.hidden_size = config.hidden_size\n",
    "        self.intermediate_size = config.intermediate_size\n",
    "        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)\n",
    "        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)\n",
    "        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        \n",
    "        down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))\n",
    "        return down_proj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DecoderLayer(nn.Module):\n",
    "    def __init__(self, config, layer_idx):\n",
    "        super().__init__()\n",
    "        self.hidden_size = config.hidden_size\n",
    "        self.self_attn = Attention(config)\n",
    "        self.mlp = MLP(config)\n",
    "        self.input_layernorm = RMSNorm(config.hidden_size)\n",
    "        self.post_attention_layernorm = RMSNorm(config.hidden_size)\n",
    "        self.layer_idx = layer_idx\n",
    "    def forward(\n",
    "        self,\n",
    "        hidden_states,\n",
    "        use_kv_cache\n",
    "    ):\n",
    "        residual = hidden_states\n",
    "\n",
    "        hidden_states = self.input_layernorm(hidden_states)\n",
    "\n",
    "        hidden_states = self.self_attn(\n",
    "            hidden_states=hidden_states,\n",
    "            use_kv_cache=use_kv_cache\n",
    "        )\n",
    "        \n",
    "        hidden_states = residual + hidden_states\n",
    "        residual = hidden_states\n",
    "        hidden_states = self.post_attention_layernorm(hidden_states)\n",
    "        hidden_states = self.mlp(hidden_states)\n",
    "        hidden_states = residual + hidden_states\n",
    "\n",
    "        outputs = hidden_states\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 编写自定义配置时需要记住的三个重要事项如下：\n",
    "# 1、必须继承自 PretrainedConfig\n",
    "# 2、PretrainedConfig 的 __init__ 方法必须接受任何 kwargs\n",
    "# 3、这些 kwargs 需要传递给超类的 __init__ 方法。\n",
    "class Config(PretrainedConfig):\n",
    "    model_type = \"small_model\"\n",
    "    \n",
    "    def __init__(self,\n",
    "                hidden_size = 512,\n",
    "                num_attention_heads = 16,\n",
    "                num_key_value_heads = 8,\n",
    "                flash_attn = True,\n",
    "                attention_bias = False,\n",
    "                max_seq_len = 512,\n",
    "                intermediate_size = 2048,\n",
    "                mlp_bias = False,\n",
    "                vocab_size = 6400,\n",
    "                n_layers = 8,\n",
    "                dropout = 0.0,\n",
    "                **kwargs):\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_attention_heads = num_attention_heads\n",
    "        self.num_key_value_heads = num_key_value_heads\n",
    "        self.flash_attn = flash_attn\n",
    "        self.attention_bias = attention_bias\n",
    "        self.max_seq_lenmax_seq_len = max_seq_len\n",
    "        self.intermediate_size = intermediate_size\n",
    "        self.mlp_bias = mlp_bias\n",
    "        self.vocab_size = vocab_size\n",
    "        self.n_layers = n_layers\n",
    "        self.dropout = dropout\n",
    "        super().__init__(**kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 为了使用transformers的Trainer，需要继承PreTrainedModel\n",
    "class LLM(PreTrainedModel):\n",
    "    config_class = Config\n",
    "    \n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        self.config = config\n",
    "        self.vocab_size = self.config.vocab_size\n",
    "        self.n_layers = self.config.n_layers\n",
    "\n",
    "        self.tokon_embeddings = nn.Embedding(self.config.vocab_size, self.config.hidden_size)\n",
    "        self.dropout = nn.Dropout(self.config.dropout) \n",
    "        self.layers = torch.nn.ModuleList() \n",
    "        for layer_idx in range(self.n_layers):\n",
    "            self.layers.append(DecoderLayer(self.config, layer_idx)) \n",
    "        self.norm = RMSNorm(self.config.hidden_size)\n",
    "        self.output = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False) \n",
    "        self.apply(self._init_weights) \n",
    "        self.loss = None \n",
    "        \n",
    "        for pn, p in self.named_parameters():\n",
    "            if pn.endswith('w3.weight') or pn.endswith('wo.weight'):\n",
    "                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layers)) \n",
    "\n",
    "    def _init_weights(self, module):\n",
    "        if isinstance(module, nn.Linear):\n",
    "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)  \n",
    "            if module.bias is not None:\n",
    "                torch.nn.init.zeros_(module.bias)  \n",
    "        elif isinstance(module, nn.Embedding):\n",
    "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) \n",
    "            \n",
    "        \n",
    "    def forward(self, input_ids, labels, use_kv_cache=False):\n",
    "       \n",
    "        hidden_states = self.tokon_embeddings(input_ids) \n",
    "        hidden_states = self.dropout(hidden_states)  \n",
    "        for idx, layer in enumerate(self.layers):\n",
    "            hidden_states = layer(hidden_states, use_kv_cache=use_kv_cache)  \n",
    "\n",
    "        hidden_states = self.norm(hidden_states) \n",
    "\n",
    "        if labels is not None:\n",
    "            logits = self.output(hidden_states)  \n",
    "            self.loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=0) \n",
    "        else:\n",
    "            logits = self.output(hidden_states[:, [-1], :])  \n",
    "            self.loss = None  \n",
    "\n",
    "        return CausalLMOutputWithPast(self.loss, logits)\n",
    "    \n",
    "    @torch.inference_mode\n",
    "    def generate(self, inputs, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.,\n",
    "                 use_kv_cache=True):\n",
    "        \n",
    "        input_ids = inputs['input_ids']\n",
    "        labels = inputs['labels']\n",
    "        s = input_ids.shape[1]\n",
    "        while input_ids.shape[1] < max_new_tokens - 1:  \n",
    "            inference_res = self(input_ids, labels, use_kv_cache=use_kv_cache)  \n",
    "            logits = inference_res.logits \n",
    "            logits = logits[:, -1, :] \n",
    "\n",
    "            for token in set(input_ids.tolist()[0]):  \n",
    "                logits[:, token] /= repetition_penalty\n",
    "\n",
    "            if temperature == 0.0: \n",
    "                _, idx_next = torch.topk(logits, k=1, dim=-1)\n",
    "            else:\n",
    "                logits = logits / temperature  \n",
    "                if top_k is not None:  \n",
    "                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n",
    "                    logits[logits < v[:, [-1]]] = -float('Inf') \n",
    "\n",
    "                probs = F.softmax(logits, dim=-1)  \n",
    "                idx_next = torch.multinomial(probs, num_samples=1, generator=None)  \n",
    "\n",
    "            if idx_next == eos:  \n",
    "                break\n",
    "\n",
    "            input_ids = torch.cat((input_ids, idx_next), dim=1)  \n",
    "            if stream:  \n",
    "                yield input_ids[:, s:]  \n",
    "\n",
    "        if not stream:  \n",
    "            yield input_ids[:, s:] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "config = Config()\n",
    "model = LLM(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "38068736"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "count_parameters(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = torch.randint(0, 1000, (2,512))\n",
    "labels = torch.randint(0, 1000, (2,512))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 512, 6400])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(inputs, labels).logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'</s>'"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = TrainingArguments(output_dir='./results', \n",
    "                         num_train_epochs=20, \n",
    "                         do_train=True, \n",
    "                         per_device_train_batch_size=1, \n",
    "                         gradient_accumulation_steps=1,\n",
    "                         group_by_length=False,\n",
    "                         max_steps=100)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LLMDataset(IterableDataset):\n",
    "    def __init__(self, data_path, tokenizer, max_seq_len):\n",
    "        super().__init__()\n",
    "        self.data_path = data_path\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_seq_len = max_seq_len\n",
    "    \n",
    "    def __iter__(self):\n",
    "        return self.data_process()\n",
    "    \n",
    "    def data_process(self):\n",
    "        with open(self.data_path, 'r', encoding='utf-8') as f:\n",
    "            for line in f:\n",
    "                line = json.loads(line)\n",
    "                text = '<s>' + line['text'] + '</s>'\n",
    "                input_ids = self.tokenizer.encode(text)\n",
    "                text_len = len(input_ids)\n",
    "                if text_len > self.max_seq_len:\n",
    "                    input_ids = input_ids[:self.max_seq_len]\n",
    "                else:\n",
    "                    input_ids = input_ids + [0] * (self.max_seq_len - text_len)\n",
    "                input_ids = np.array(input_ids)\n",
    "                X = np.array(input_ids[:-1]).astype(np.int64)\n",
    "                Y = np.array(input_ids[1:]).astype(np.int64)\n",
    "                yield {\n",
    "                    'input_ids': torch.from_numpy(X),\n",
    "                    'labels': torch.from_numpy(Y),\n",
    "                }\n",
    "                \n",
    "                \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = LLMDataset('./dataset/pretrain_data.jsonl', tokenizer=tokenizer, max_seq_len=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SFTDataset(Dataset):\n",
    "    def __init__(self, data_path, tokenizer, max_seq_len):\n",
    "        super().__init__()\n",
    "        self.data_path = data_path\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_seq_len = max_seq_len\n",
    "        \n",
    "        with open(self.data_path, 'r', encoding='utf-8') as f:\n",
    "            self.data = f.readlines()\n",
    "            \n",
    "    def __len__(self):\n",
    "        return len(self.data)    \n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        line = self.data[index]\n",
    "        line = json.loads(line)\n",
    "        instruction_text = line['instruction']\n",
    "        input_text = line['input']\n",
    "        output_text = line['output']\n",
    "        history = line['history']\n",
    "        query = instruction_text + input_text\n",
    "        answer = output_text + self.tokenizer.eos_token\n",
    "        messages = []\n",
    "        if history:\n",
    "            for i in history:\n",
    "                messages.append({'role': 'user', 'content': i[0]})\n",
    "                messages.append({'role': 'assistant', 'content': i[1]})\n",
    "        \n",
    "        messages.append({'role': 'user', 'content': query})   \n",
    "        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) \n",
    "        prompt_input_ids = self.tokenizer.encode(prompt)\n",
    "        answer_input_ids = self.tokenizer.encode(answer)\n",
    "        input_ids = prompt_input_ids + answer_input_ids\n",
    "        labels = [0] * len(prompt_input_ids) + answer_input_ids\n",
    "        text_len = len(input_ids)\n",
    "        if text_len > self.max_seq_len:\n",
    "            input_ids = input_ids[:self.max_seq_len]\n",
    "            labels = labels[:self.max_seq_len]\n",
    "        else:\n",
    "            input_ids = input_ids + [0] * (self.max_seq_len - text_len)\n",
    "            labels = labels + [0] * (self.max_seq_len - text_len)\n",
    "        \n",
    "        input_ids = input_ids[:-1]\n",
    "        labels = labels[1:]\n",
    "        return {'input_ids': torch.tensor(input_ids), 'labels': torch.tensor(labels)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[22], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m dataset \u001b[38;5;241m=\u001b[39m \u001b[43mSFTDataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m/home/user/wyf/sft_data_zh.jsonl\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtokenizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_seq_len\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m2048\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[20], line 9\u001b[0m, in \u001b[0;36mSFTDataset.__init__\u001b[0;34m(self, data_path, tokenizer, max_seq_len)\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_seq_len \u001b[38;5;241m=\u001b[39m max_seq_len\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdata_path, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m, encoding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[0;32m----> 9\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdata \u001b[38;5;241m=\u001b[39m \u001b[43mf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreadlines\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m<frozen codecs>:319\u001b[0m, in \u001b[0;36mdecode\u001b[0;34m(self, input, final)\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "dataset = SFTDataset('/home/user/wyf/sft_data_zh.jsonl', tokenizer=tokenizer, max_seq_len=2048)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'PreTrainedTokenizerFast' object has no attribute 'eos'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[21], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\n",
      "Cell \u001b[0;32mIn[16], line 22\u001b[0m, in \u001b[0;36mSFTDataset.__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m     20\u001b[0m history \u001b[38;5;241m=\u001b[39m line[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mhistory\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m     21\u001b[0m query \u001b[38;5;241m=\u001b[39m instruction_text \u001b[38;5;241m+\u001b[39m input_text\n\u001b[0;32m---> 22\u001b[0m answer \u001b[38;5;241m=\u001b[39m output_text \u001b[38;5;241m+\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meos\u001b[49m\n\u001b[1;32m     23\u001b[0m messages \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m     24\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m history:\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'PreTrainedTokenizerFast' object has no attribute 'eos'"
     ]
    }
   ],
   "source": [
    "dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_collator = DefaultDataCollator()\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"./tokenizer\")\n",
    "args = TrainingArguments(output_dir='./results', \n",
    "                        num_train_epochs=20, \n",
    "                        do_train=True, \n",
    "                        per_device_train_batch_size=2,\n",
    "                        gradient_accumulation_steps=1,\n",
    "                        group_by_length=False,\n",
    "                        max_steps=10,\n",
    "                        logging_steps=10,\n",
    "                        report_to = 'none')            \n",
    "dataset = LLMDataset('./dataset/pretrain_data.jsonl', tokenizer=tokenizer, max_seq_len=512)\n",
    "trainer = Trainer(model=model, args=args, train_dataset=dataset, tokenizer=tokenizer, data_collator=data_collator)\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save_model('./saves/model')\n",
    "trainer.save_state()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "wyf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
